import time
from typing import Optional

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
import json
import uuid
import hashlib
import requests
import logging
# Module-level logger. The root logger is deliberately NOT configured here
# (no logging.basicConfig): handlers and levels belong to the host
# application, and forcing DEBUG on the root logger at import time would make
# every module in the process log verbosely.
logger = logging.getLogger(__name__)

class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
    """
    Text embedding model served over HTTP behind a signed gateway.

    Every request carries ``AppKey`` / ``Sign`` / ``Timestamp`` / ``appRequestId``
    headers (built by ``_build_signed_headers``). Each input text is embedded
    with its own POST to ``credentials["server_url"]``.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        credentials should be like:
        {
            'server_url': 'full URL of the embedding endpoint (used as-is)',
            'api_key': 'gateway app key',
            'app_secret': 'optional; overrides the built-in gateway app secret',
        }

        Texts are sent one per request, in order, so ``embeddings[i]`` belongs
        to ``texts[i]``. No client-side truncation or batching is applied.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id (not sent to the server)
        :param input_type: input type (not sent to the server)
        :return: embeddings result
        :raises KeyError: if ``server_url`` is missing from ``credentials``
        :raises InvokeError: a subclass describing a transport or HTTP failure
        """
        url = credentials["server_url"]

        embeddings = [self._embed_one(url, text, credentials) for text in texts]

        # NOTE: the endpoint reports no token usage, so one unit per input text
        # is recorded as a placeholder rather than a real token count.
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=len(texts))

        return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)

    @staticmethod
    def _build_signed_headers(credentials: dict) -> dict:
        """
        Build the gateway authentication headers for a single request.

        ``Sign`` is ``md5(api_key + app_secret + timestamp + request_id)`` as
        lowercase hex. ``timestamp`` is the current Unix time in milliseconds,
        and ``request_id`` is a random UUID4 without dashes.

        :param credentials: model credentials; ``api_key`` is read with ``.get``,
            and a non-empty ``app_secret`` overrides the built-in secret
        :return: headers to send with the request
        """
        # Built-in app secret of the computing-platform gateway.
        # NOTE(review): a secret committed to source is a leak. Move it to
        # configuration (the ``app_secret`` credential) and rotate it.
        default_app_secret = "aa158edef0f2d9bea7fdbfba9559cb24"

        api_key = credentials.get("api_key")
        app_secret = credentials.get("app_secret") or default_app_secret
        timestamp = str(int(time.time() * 1000))
        request_id = uuid.uuid4().hex  # same 32-hex-digit form as str(uuid4()) without dashes
        sign = hashlib.md5(f"{api_key}{app_secret}{timestamp}{request_id}".encode()).hexdigest()
        return {
            "Content-Type": "application/json",
            "Sign": sign,
            "Timestamp": timestamp,
            "AppKey": api_key,
            "appRequestId": request_id,
        }

    def _embed_one(self, url: str, text: str, credentials: dict) -> list[float]:
        """
        POST one text as ``{"inputs": text}`` and return its embedding vector.

        A fresh signature is built per request, so ``Timestamp`` and
        ``appRequestId`` are never reused across the texts of one batch.

        :raises InvokeConnectionError: on connection failure or timeout
        :raises InvokeAuthorizationError: on HTTP 401/403
        :raises InvokeRateLimitError: on HTTP 429
        :raises InvokeBadRequestError: on any other HTTP 4xx
        :raises InvokeServerUnavailableError: on HTTP 5xx or an unusable body
        """
        headers = self._build_signed_headers(credentials)
        # Headers and payload are deliberately not logged: they carry the
        # AppKey/signature and user content.
        logger.debug("invoke embedding endpoint, url: %s, input chars: %d", url, len(text))
        try:
            response = requests.post(
                url,
                headers=headers,
                json={"inputs": text},
                # NOTE(review): TLS verification stays disabled as in the original
                # code, presumably for an internal endpoint with a self-signed
                # certificate. Confirm before enabling.
                verify=False,
                # (connect, read) seconds; requests has no default timeout, so a
                # stalled server would otherwise block the worker indefinitely.
                timeout=(10, 120),
            )
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            raise InvokeConnectionError(str(e)) from e

        status = response.status_code
        if status in (401, 403):
            raise InvokeAuthorizationError(response.text)
        if status == 429:
            raise InvokeRateLimitError(response.text)
        if status >= 500:
            raise InvokeServerUnavailableError(response.text)
        if status >= 400:
            raise InvokeBadRequestError(response.text)

        try:
            body = response.json()
        except ValueError as e:
            raise InvokeServerUnavailableError(f"Embedding server returned a non-JSON body: {e}") from e

        # The expected shape is a nested array [[v0, v1, ...]] whose first row is
        # this text's vector. A flat [v0, v1, ...] vector is accepted as well.
        if not isinstance(body, list) or not body:
            raise InvokeServerUnavailableError(f"Unexpected embedding response: {str(body)[:200]}")
        return body[0] if isinstance(body[0], list) else body

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for the given texts

        The TEI tokenize call is not used for this endpoint (presumably
        unsupported by the gateway), so this always returns 0.
        NOTE(review): callers that size chunks by this value get no length
        signal. Consider a local tokenizer estimate; confirm with the owner.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Calls ``TeiHelper.get_tei_extra_parameter`` against ``server_url`` with
        freshly signed gateway headers. Only reachability and authentication are
        exercised: the returned model metadata is not checked (the model-type and
        context-size checks are disabled for this gateway).

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: on any failure, including a
            missing ``server_url``
        """
        try:
            server_url = credentials["server_url"]
            headers = self._build_signed_headers(credentials)
            TeiHelper.get_tei_extra_parameter(server_url, model, headers)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map unified invoke errors to the exception types that produce them.

        ``_embed_one`` raises the unified error types directly. ``KeyError``
        (e.g. missing ``server_url``) is reported as a bad request.
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        Best-effort: a non-int ``tokens``, a missing ``started_at`` or a failing
        price lookup fall back to zero values instead of failing the call.

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # Fallback for tokens (None is not an int, so it is covered too).
        if not isinstance(tokens, int):
            tokens = 0

        # started_at is presumably set by the base class before _invoke runs;
        # fall back to zero latency when it is missing or unusable.
        try:
            latency = time.perf_counter() - self.started_at
        except (AttributeError, TypeError):
            latency = 0.0

        # Fallback for get_price: a pricing failure must not break embedding.
        try:
            price_info = self.get_price(
                model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
            )
            unit_price = price_info.unit_price
            price_unit = price_info.unit
            total_price = price_info.total_amount
            currency = price_info.currency
        except Exception:
            from decimal import Decimal

            logger.debug("price lookup failed for model %s; recording zero cost", model, exc_info=True)
            unit_price = price_unit = total_price = Decimal("0.0")
            currency = "USD"

        return EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=unit_price,
            price_unit=price_unit,
            total_price=total_price,
            currency=currency,
            latency=latency,
        )

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Define the schema of a customizable text embedding model.

        ``max_chunks`` (default 1) and ``context_size`` (default 512) are read
        from ``credentials`` and converted with ``int()``; no parameter rules are
        exposed.

        :param model: model name, also used as the English label
        :param credentials: model credentials
        :return: the model entity
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
            },
            parameter_rules=[],
        )
